import torch
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.colors import Normalize

def scatter(x, y, Z, trajectories, args):
    """Plot Z on the (x, y) grid as a colored scatter "heatmap" and save it as a PNG.

    Optionally overlays named trajectories (drawn as lines) on top of the
    heatmap, with a compact legend in the lower-left corner.

    Args:
        x: Grid coordinates along the X axis (passed to ``np.meshgrid``).
        y: Grid coordinates along the Y axis (passed to ``np.meshgrid``).
        Z: Values at the grid points. It is flattened and paired element-wise
            with the flattened ``np.meshgrid(x, y)`` output, so it presumably
            has shape ``(len(y), len(x))`` — TODO confirm with callers.
        trajectories: Mapping of label -> path, or a falsy value for none.
            Each path is converted with ``np.array`` and drawn as ``path[0]``
            (x values) against ``path[1]`` (y values).
        args: Namespace providing ``vmin``/``vmax`` (color-scale limits) and
            ``task_name``, ``id``, ``xmin``, ``xmax``, ``xnum``, ``ymin``,
            ``ymax``, ``ynum``, ``vlevel``, which are used to build the output
            file path under ``../image/{task_name}/``.
    """
    import os

    X, Y = np.meshgrid(x, y)
    norm = Normalize(vmin=args.vmin, vmax=args.vmax)  # linear color normalization
    fig = plt.figure(figsize=(8, 6))
    try:
        scatter_plot = plt.scatter(X.flatten(), Y.flatten(), c=Z.flatten(), cmap='viridis', s=1, norm=norm)
        plt.colorbar(scatter_plot, label='Z values')

        if trajectories:
            # Cycle styles and colors independently, so the (style, color)
            # pairs keep varying when there are more trajectories than entries
            # in either list.
            style_cycle = cycle(['-', '--', '-.', ':'])
            color_cycle = cycle(['r', 'g', 'b', 'c', 'm', 'y', 'k'])
            for name, path in trajectories.items():
                path = np.array(path)
                plt.plot(path[0], path[1], linestyle=next(style_cycle), color=next(color_cycle), label=f'{name}', alpha=0.9)
            # Build the legend only when labeled lines exist. Without them,
            # matplotlib warns "No artists with labels found".
            plt.legend(loc='lower left', fontsize=6, bbox_to_anchor=(0, 0), frameon=False,
                       handletextpad=0.5,  # gap between each legend marker and its text
                       labelspacing=0.5)  # vertical gap between legend entries

        plt.title('2D Heatmap Scatter Plot')
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')

        out_dir = f'../image/{args.task_name}'
        # savefig does not create missing directories, so make sure it exists.
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f'{out_dir}/{args.id}_scatter_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}.png')
    finally:
        # Always release the figure. Otherwise repeated calls accumulate open
        # figures (a memory leak plus matplotlib's "too many figures" warning).
        plt.close(fig)
